Added support for log_model='best_and_last' option in wandb logger #9356
Conversation
Codecov Report
@@            Coverage Diff            @@
##           master    #9356    +/-   ##
=========================================
- Coverage      88%      88%     -0%
=========================================
  Files         178      178
  Lines       14895    14900     +5
=========================================
+ Hits        13138    13141     +3
- Misses       1757     1759     +2
logger.experiment.project_name.return_value = "project"
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
trainer.fit(model)
assert wandb.init().log_artifact.call_count == 2
This does not test that the other versions are properly removed. Can we also test/mock this somehow?
For sure. I will figure out how this should be properly tested.
Do you think that this should be done with mocking? @justusschock
Maybe you should train for 5 epochs in this test
@borisdayma mind reviewing this PR?
api = wandb.Api(overrides={"project": self.experiment.project})

for version in api.artifact_versions(f"model-{self.experiment.id}", "model"):
    # Clean up all versions that don't have an alias such as 'latest'.
    if len(version.aliases) == 0:
        version.delete()
Noob question, I am not familiar with the wandb API.
It seems models are being versioned and you are deleting all versions which don't have either the latest or best aliases.
I am not sure I grasp why this would save only the best and last model weights.
Furthermore, I don't think this would work with multiple ModelCheckpoint callbacks. Should we save the monitor as metadata to perform the filtering?
best_and_last should produce at maximum num_model_checkpoints + 1 checkpoints, right?
We automatically tag some artifact versions (model checkpoints): we tag the "latest" one and we tag the "best" one when a monitored value is defined (they can point to the same model). So there are at most 2 tagged versions.
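For illustration, a minimal sketch of how a checkpoint version could be tagged this way with the plain wandb API; the project name and checkpoint path are made up for the example, and this is not the exact Lightning implementation:

```python
import wandb

# Assumes a run is active; project and file names are illustrative.
run = wandb.init(project="my-project")

artifact = wandb.Artifact(name=f"model-{run.id}", type="model")
artifact.add_file("checkpoints/epoch=4-step=500.ckpt")

# The new version gets the "latest" alias (and "best" if it is currently the
# best checkpoint); older versions lose these aliases, end up alias-free, and
# can then be deleted by the cleanup loop shown above.
run.log_artifact(artifact, aliases=["latest", "best"])
```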
@borisdayma Yah exactly. The implementation relies on the fact that there are at most 2 live aliases at the same time.
Several aliases can probably be included (best_0, best_1, best_2, ..., latest), but this is another PR :)
I believe best_{monitor_0}, best_{monitor_1}, best_{monitor_2} would be better; it would let users navigate their weights more easily in the Wandb UI.
The idea is that we directly leverage ModelCheckpoint to identify best metrics (easier to maintain the callback, avoid replicating the same logic, and maybe easier for users).
You can see an example here.
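Roughly, that pattern looks like the sketch below; this is a simplified illustration of the logger hook, not the actual WandbLogger code, and it assumes the 'latest'/'best' alias scheme discussed above:

```python
import os

import wandb


def after_save_checkpoint(self, checkpoint_callback):
    # Simplified sketch: reuse ModelCheckpoint's bookkeeping instead of
    # re-implementing "which checkpoint is best" inside the logger.
    path = checkpoint_callback.last_model_path or checkpoint_callback.best_model_path
    if not path or not os.path.isfile(path):
        return

    artifact = wandb.Artifact(name=f"model-{self.experiment.id}", type="model")
    artifact.add_file(path)

    # "best" is only attached when a metric is monitored and this checkpoint
    # is currently the best one according to ModelCheckpoint.
    aliases = ["latest"]
    if path == checkpoint_callback.best_model_path:
        aliases.append("best")

    self.experiment.log_artifact(artifact, aliases=aliases)
```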
Looks good to me. I'm just curious about the naming and whether we could offer this functionality in a cleaner way.
An alternative could be to just add an argument such as "remove_unused_artifacts" to the logger, which would be False by default.
Then we could just have a separate function for this purpose. What do you think?
Great idea @borisdayma to make the naming cleaner. However, I think that we need to take this a step further. So we would pass … Let's examine the current functionality of …
Therefore, I suggest:
What do you think?
So my concern was related to people with limited network bandwidth, whether upload speed or a max quota before being throttled. I have no issue with changing argument names, but that requires backward compatibility, so it's better to avoid it unless it's clearly beneficial.
logger = WandbLogger(log_model="best_and_last")
logger.experiment.id = "1"
logger.experiment.project_name.return_value = "project"
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
Should this test use a ModelCheckpoint?
Yup. We should definitely build a better test. I am just not familiar with the whole "mocking" thing, or with tests in general. Should my test function also be decorated with mocking?
If not, then this requires some default wandb setup to run the experiment. I'm not sure how to get this right, as I have never written tests for a project as large as Lightning :)
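One possible direction, sketched only: the existing wandb logger tests patch the whole wandb module, so the cleanup could be asserted on the mocked wandb.Api without any real wandb setup. The test name, the BoringModel import path, and the exact assertions are assumptions, not a finished test:

```python
from unittest import mock

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from tests.helpers import BoringModel  # path as used in the Lightning test suite


@mock.patch("pytorch_lightning.loggers.wandb.wandb")
def test_wandb_log_model_best_and_last_cleanup(wandb, tmpdir):
    # Two fake artifact versions: one stale (no aliases) and one still aliased.
    stale_version = mock.MagicMock(aliases=[])
    kept_version = mock.MagicMock(aliases=["latest"])
    wandb.Api.return_value.artifact_versions.return_value = [stale_version, kept_version]

    logger = WandbLogger(log_model="best_and_last")
    logger.experiment.id = "1"
    logger.experiment.project_name.return_value = "project"
    trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
    trainer.fit(BoringModel())

    # Only the version without an alias should have been removed.
    stale_version.delete.assert_called()
    kept_version.delete.assert_not_called()
```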
Ok so I really like the purpose of this PR. I think there are a few things to decide:
Regarding the network bandwidth, I agree that this might be a problem for some people. How about this:
What do you think? @borisdayma @tchaton As in all engineering, there will be a tradeoff eventually...
Hey @borisdayma, quick question about wandb internals. Are the models uploaded in a separate thread / process? What would the impact on training be of logging all models? Best,
It is performed in a separate process so should not impact training at all (even if the wandb process were to fail).
That's a great idea! In addition, I suggest:
Hey @borisdayma, @ohayonguy, yes, I like the proposal too. Should we adapt this PR directly with the suggested changes? Best,
As you guys want! I can work on it later this week unless @ohayonguy wants to update this PR?
Hey guys @borisdayma @tchaton
This pull request has been automatically marked as stale because it has not had recent activity. It will be closed in 7 days if no further activity occurs. If you need further help see our docs: https://pytorch-lightning.readthedocs.io/en/latest/generated/CONTRIBUTING.html#pull-request or ask the assistance of a core contributor here or on Slack. Thank you for your contributions.
This pull request is going to be closed. Please feel free to reopen it or create a new one from the current master.
Hi everyone 👋🏻! Thanks for your help!
Any updates on this? Will this feature be merged?
+1 for this feature. Anyone wanting to train a large model will be interested in this. Due to the large model size, keeping a full history of checkpoints takes up too much disk space, but waiting until the end of training to upload the top-k model checkpoints is too vulnerable to crashes.
What does this PR do?
As a follow-up to the discussion in #9342 (reply in thread):
Currently, the Weights and Biases logger is able to log checkpoints as artifacts either at the end of training or during training (any time a new checkpoint is created by the ModelCheckpoint callback). This feature is controlled with the log_model flag, where log_model=True means to log at the end of training, and log_model='all' means to log all checkpoints during training. However, the latter option can blow up the artifact storage, since many checkpoints may be saved, which is undesirable especially when using wandb local.
Thus, I added a log_model='best_and_last' option, which keeps only the best and the last checkpoints in the wandb artifact storage. Previous artifact versions are automatically deleted. No extra dependencies are required, only wandb.
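For reference, a usage sketch of the proposed option; the project name, monitored metric, and LightningModule are placeholders:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

# Keep only the best and the most recent checkpoint in the wandb artifact storage;
# every other artifact version is deleted as new checkpoints are uploaded.
wandb_logger = WandbLogger(project="my-project", log_model="best_and_last")
checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min")

trainer = Trainer(logger=wandb_logger, callbacks=[checkpoint_callback], max_epochs=10)
# trainer.fit(model)  # `model` is any LightningModule that logs "val_loss"
```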
Does your PR introduce any breaking changes? If yes, please list them.
None
Did you have fun? yup :)
Make sure you had fun coding 🙃